//
//  cnn.c
//  CNN
//
//  Created by fanwenjie on 2016/10/14.
//  Copyright © 2016年 fanwenjie. All rights reserved.
//

#include "cnn.h"

#include <math.h>
#include <memory.h>
#include <stdint.h>
#include <stdlib.h>
/* Spell GNU inline assembly through the __asm keyword, which GCC/Clang also
 * accept in strict ISO modes (-std=c11) where plain `asm` is not a keyword. */
#define asm __asm
/* ALIGN(n, align) = smallest x >= n with x % align == 0 (align > 0). */
#define ALIGN(n,align) (((align)-1+(n))/(align)*(align))

/* sqrt is intentionally not re-#defined to sqrtf: <math.h> reserves that name
 * (C11 7.1.3), and float callers get the same correctly rounded result from
 * sqrt or can call sqrtf directly. */

/* Lane-wise des = src0 + src1 for one pack_t (inline AVX). */
static inline void Vaddps(const pack_t src0, const pack_t src1, pack_t des);

/* Valid 2-D convolution of a (dh+ch-1) x (dw+cw-1) map src with a ch x cw
 * kernel conv, accumulated (not overwritten) into the dh x dw map des. */
static void ConvoluteValid(const pack_t *src, const pack_t *conv, pack_t *des, const long dh, const long dw, const long ch, const long cw);
/* Max-pool the sh x sw map src into the dh x dw map des.
 * NOTE(review): des is const-qualified here although the function writes the
 * pooled values through it; the definition must use the same qualifiers. */
static void SubsampMaxForward(const pack_t *src, const pack_t *des, const long sh, const long sw, const long dh, const long dw);
/* des[y] += sum over x < height of src[x] * mat[x * width + y], lane-wise. */
static void VectorXMatrix(const pack_t *src, const pack_t *mat, pack_t *des, const long height, const long width);
/* Per-lane argmax of the OUTPUT class scores; one label per lane. */
static void GetResult(pack_t output[OUTPUT], int labels[SZPACK]);

/*
 * Scratch storage for one forward pass; CNNPredict allocates it zero-filled
 * and aligns its base to sizeof(pack_t).
 *
 * Layer k holds LAYERk feature maps of HEIGHT_FEATUREk x WIDTH_FEATUREk cells.
 * Each cell is a pack_t of SZPACK float lanes. Data flow in Forward():
 *   layer0 -conv-> layer1 -maxpool-> layer2 -conv-> layer3 -maxpool-> layer4
 *   -conv-> layer5 -dense-> output
 *
 * NOTE(review): some routines in this file use aligned AVX loads/stores
 * (vmovaps), so every member must remain pack_t-aligned. Presumably this holds
 * because sizeof(pack_t) equals its alignment; confirm in cnn.h before
 * reordering or adding fields.
 */
typedef struct
{
	pack_t layer0[LAYER0][HEIGHT_FEATURE0][WIDTH_FEATURE0];
	pack_t layer1[LAYER1][HEIGHT_FEATURE1][WIDTH_FEATURE1];
	pack_t layer2[LAYER2][HEIGHT_FEATURE2][WIDTH_FEATURE2];
	pack_t layer3[LAYER3][HEIGHT_FEATURE3][WIDTH_FEATURE3];
	pack_t layer4[LAYER4][HEIGHT_FEATURE4][WIDTH_FEATURE4];
	pack_t layer5[LAYER5][HEIGHT_FEATURE5][WIDTH_FEATURE5];
	pack_t output[OUTPUT];
}FeaturePack;

/*
 * Rectified linear unit: x when x > 0, otherwise +0.
 * Uses a select rather than x * (x > 0): the product turns -inf into NaN
 * (-inf * 0) and yields -0 for negative inputs. NaN inputs map to 0 because
 * the comparison is false.
 */
static float Relu(const float x)
{
	return x > 0.f ? x : 0.f;
}

/* Identity activation: the final dense layer's raw scores are passed through
 * unchanged so GetResult can take their argmax. */
static float None(const float x)
{
	const float unchanged = x;
	return unchanged;
}

/* Number of elements of a true array. Do not use on pointers or on array
 * parameters, which have decayed to pointers. */
#define GETLENGTH(array) (sizeof(array)/sizeof(*(array)))

/* for (long i = 0; i < count; ++i). count is parenthesised so that arguments
 * such as `n & 3` or `a ? b : c` keep their intended meaning. */
#define FOREACH(i,count) for (long i = 0; i < (count); ++i)

/* Valid-convolve one input map with one kernel, accumulating into one output
 * map. All three arguments must be true 2-D arrays of pack_t: the output and
 * kernel dimensions are taken from sizeof. do/while(0) makes the macro behave
 * like a single statement (safe before `else`). */
#define CONVOLUTE_VALID(input,output,weight)								\
do {																		\
	ConvoluteValid((pack_t *)(input), (pack_t *)(weight), (pack_t *)(output),	\
		GETLENGTH(output), GETLENGTH(*(output)),							\
		GETLENGTH(weight), GETLENGTH(*(weight)));							\
} while (0)





/*
 * Convolutional layer, operating on the arrays named by the arguments:
 *   1. output[y] += valid convolution of input[x] with weight[x][y], summed
 *      over every input map x (weight is [input map][output map][kh][kw]);
 *   2. bias[x] is added lane-wise to every pack_t of output map x;
 *   3. action is applied to every float of output.
 * All arguments must be true arrays (not pointers), because sizes come from
 * sizeof. output must start zeroed because step 1 accumulates.
 */
#define CONVOLUTION_FORWARD(input,output,weight,bias,action)							\
do {																					\
	FOREACH(x, GETLENGTH(weight))														\
		FOREACH(y, GETLENGTH(*(weight)))												\
			CONVOLUTE_VALID((input)[x], (output)[y], (weight)[x][y]);					\
	FOREACH(x, GETLENGTH(output))														\
		FOREACH(y, sizeof((output)[x]) / sizeof(pack_t))								\
			Vaddps(((pack_t *)(output)[x])[y], (bias)[x], ((pack_t *)(output)[x])[y]);	\
	FOREACH(i, sizeof(output) / sizeof(float))											\
		((float *)(output))[i] = action(((float *)(output))[i]);						\
} while (0)

/*
 * Max-pooling layer: output map j is input map j reduced by the lane-wise
 * maximum over non-overlapping windows of
 * (input height / output height) x (input width / output width) cells.
 */
#define SUBSAMP_MAX_FORWARD(input,output)												\
do {																					\
	FOREACH(j, GETLENGTH(output))														\
		SubsampMaxForward((pack_t *)(input)[j], (pack_t *)(output)[j],					\
			GETLENGTH(*(input)), GETLENGTH(**(input)),									\
			GETLENGTH(*(output)), GETLENGTH(**(output)));								\
} while (0)


/*
 * Fully connected layer: output += flatten(input) x weight (weight is
 * [inputs][outputs] of pack_t). Then output[j] = action(output[j] + bias[j])
 * for every float j of bias. output must start zeroed.
 */
#define DOT_PRODUCT_FORWARD(input,output,weight,bias,action)							\
do {																					\
	VectorXMatrix((pack_t *)(input), (pack_t *)(weight), (pack_t *)(output),			\
		GETLENGTH(weight), GETLENGTH(*(weight)));										\
	FOREACH(j, sizeof(bias) / sizeof(float))											\
		((float *)(output))[j] = action(((float *)(output))[j] + ((float *)(bias))[j]);	\
} while (0)

/*
 * One forward pass. Each stage reads one member of *featurePack and writes
 * the next:
 *   conv+ReLU -> max-pool -> conv+ReLU -> max-pool -> conv+ReLU -> dense (identity).
 * *featurePack must be zero-initialised beforehand, because the convolution
 * and dense stages accumulate into their outputs.
 */
static void Forward(const CNN *cnn, FeaturePack *featurePack)
{
	const CNN *const net = cnn;
	FeaturePack *const fp = featurePack;

	/* first convolution + pooling stage */
	CONVOLUTION_FORWARD(fp->layer0, fp->layer1, net->weight0_1, net->bias0_1, Relu);
	SUBSAMP_MAX_FORWARD(fp->layer1, fp->layer2);

	/* second convolution + pooling stage */
	CONVOLUTION_FORWARD(fp->layer2, fp->layer3, net->weight2_3, net->bias2_3, Relu);
	SUBSAMP_MAX_FORWARD(fp->layer3, fp->layer4);

	/* final convolution, then the linear classifier producing the scores */
	CONVOLUTION_FORWARD(fp->layer4, fp->layer5, net->weight4_5, net->bias4_5, Relu);
	DOT_PRODUCT_FORWARD(fp->layer5, fp->output, net->weight5_6, net->bias5_6, None);
}


/*
 * Normalise one image to zero mean and unit standard deviation, then copy it
 * into every lane of layer0, offset by `pad` rows and `pad` columns. Cells of
 * layer0 outside the copied image are left untouched (CNNPredict zero-fills
 * the buffer).
 *
 * Pixel sums are accumulated exactly in 64-bit integers and the statistics in
 * double. A flat image (standard deviation 0) normalises to all zeros instead
 * of dividing by zero. A variance that rounds slightly negative is clamped to
 * 0 instead of making sqrt return NaN.
 * NOTE(review): assumes image_t holds unsigned 8-bit pixels; confirm in cnn.h.
 */
static void LoadInput(pack_t(*layer0)[HEIGHT_FEATURE0][WIDTH_FEATURE0], const image_t input,const int pad)
{
	const long rows = (long)(sizeof(image_t) / sizeof(*input));
	const long cols = (long)(sizeof(*input) / sizeof(**input));

	unsigned long long sum = 0, sumSquares = 0;
	for (long x = 0; x < rows; ++x)
		for (long y = 0; y < cols; ++y)
		{
			const unsigned long long pixel = input[x][y];
			sum += pixel;
			sumSquares += pixel * pixel;
		}

	const double count = (double)(rows * cols);
	const double mean = (double)sum / count;
	double variance = (double)sumSquares / count - mean * mean;
	if (variance < 0.0)
		variance = 0.0;	/* guard against rounding below zero */
	double stddev = sqrt(variance);
	if (stddev == 0.0)
		stddev = 1.0;	/* flat image: every pixel equals the mean, so all values become 0 */

	for (long x = 0; x < rows; ++x)
		for (long y = 0; y < cols; ++y)
		{
			const float value = (float)((input[x][y] - mean) / stddev);
			for (long i = 0; i < SZPACK; ++i)
				(*layer0)[x + pad][y + pad][i] = value;
		}
}

void CNNPredict(const CNN *cnn,const image_t input, int result[SZPACK])
{
    char *buffer = (char *)malloc(sizeof(FeaturePack) + sizeof(pack_t) - 1);
    memset(buffer, 0 ,sizeof(FeaturePack) + sizeof(pack_t) - 1);
	//char buffer[sizeof(FeaturePack) + sizeof(pack_t) - 1] = { 0 };
	FeaturePack *featurePack = (FeaturePack *)ALIGN((unsigned long long)buffer, sizeof(pack_t));
	LoadInput(featurePack->layer0, input, PAD);
	Forward(cnn, featurePack);
	GetResult(featurePack->output, result);
    free(buffer);
}

/*
 * Lane-wise sum of two packs: des = src0 + src1, using one 256-bit ymm register.
 * The load, add and store happen in a single asm statement, so nothing has to
 * survive in %ymm0 between separate asm statements. The "memory" clobber tells
 * the compiler that memory is read and written through the pointer operands.
 * vmovaps faults unless each operand is 32-byte aligned.
 */
static inline void Vaddps(const pack_t src0, const pack_t src1, pack_t des)
{
	__asm__(
		"vmovaps (%[lhs]), %%ymm0\n\t"
		"vaddps (%[rhs]), %%ymm0, %%ymm0\n\t"
		"vmovaps %%ymm0, (%[dst])\n\t"
		: /* no register outputs; the result goes to memory */
		: [lhs] "r"(src0), [rhs] "r"(src1), [dst] "r"(des)
		: "%ymm0", "memory");
}

/*
 * "Valid" 2-D convolution (cross-correlation; the kernel is not flipped) of
 * one input map with one kernel, accumulated lane-wise into des:
 *
 *   des[d0][d1] += sum_{c0 < ch, c1 < cw} src[d0 + c0][d1 + c1] * conv[c0][c1]
 *
 * des is dh x dw, conv is ch x cw and src is (dh + ch - 1) x (dw + cw - 1), all
 * row-major arrays of pack_t. Accumulating rather than overwriting lets the
 * caller sum the contributions of several input maps.
 *
 * Written in plain C on purpose: separate asm statements cannot carry an
 * accumulator in %ymm0 from one statement to the next, because the compiler
 * may use the register in between. The lane loops are simple for the compiler
 * to vectorise.
 */
static void ConvoluteValid(const pack_t *src, const pack_t *conv, pack_t *des, const long dh, const long dw, const long ch, const long cw)
{
	const long sw = dw + cw - 1;	/* width of the input map */
	const size_t lanes = sizeof(pack_t) / sizeof(float);

	/* Compiler barrier: other routines in this file may touch pack_t buffers
	 * through inline asm that does not declare its memory accesses. */
	__asm__ __volatile__("" ::: "memory");
	for (long d0 = 0; d0 < dh; ++d0)
		for (long d1 = 0; d1 < dw; ++d1)
		{
			float acc[sizeof(pack_t) / sizeof(float)];
			float *d = des[d0 * dw + d1];
			for (size_t k = 0; k < lanes; ++k)
				acc[k] = d[k];
			for (long c0 = 0; c0 < ch; ++c0)
				for (long c1 = 0; c1 < cw; ++c1)
				{
					const float *s = src[(d0 + c0) * sw + d1 + c1];
					const float *w = conv[c0 * cw + c1];
					for (size_t k = 0; k < lanes; ++k)
						acc[k] += w[k] * s[k];
				}
			for (size_t k = 0; k < lanes; ++k)
				d[k] = acc[k];
		}
	__asm__ __volatile__("" ::: "memory");
}



/*
 * Lane-wise vector-matrix product, accumulated into des:
 *
 *   des[y] += sum_{x < height} src[x] * mat[x * width + y]   for y < width
 *
 * src has height packs, mat is a row-major height x width array of pack_t and
 * des has width packs.
 *
 * Plain C on purpose: separate asm statements cannot carry an accumulator in
 * %ymm0 from one statement to the next, because the compiler may use the
 * register in between.
 */
static void VectorXMatrix(const pack_t *src, const pack_t *mat, pack_t *des, const long height, const long width)
{
	const size_t lanes = sizeof(pack_t) / sizeof(float);

	/* Compiler barrier: other routines in this file may touch pack_t buffers
	 * through inline asm that does not declare its memory accesses. */
	__asm__ __volatile__("" ::: "memory");
	for (long y = 0; y < width; ++y)
	{
		float acc[sizeof(pack_t) / sizeof(float)];
		float *d = des[y];
		for (size_t k = 0; k < lanes; ++k)
			acc[k] = d[k];
		for (long x = 0; x < height; ++x)
		{
			const float *s = src[x];
			const float *m = mat[x * width + y];
			for (size_t k = 0; k < lanes; ++k)
				acc[k] += m[k] * s[k];
		}
		for (size_t k = 0; k < lanes; ++k)
			d[k] = acc[k];
	}
	__asm__ __volatile__("" ::: "memory");
}

static void SubsampMaxForward(const pack_t *src, const pack_t *des, const long sh, const long sw, const long dh, const long dw)
{
	const long lh = sh / dh, lw = sw / dw;
	for (long d0 = 0; d0 < dh; ++d0)
		for (long d1 = 0; d1 < dw; ++d1)
		{
			asm("vmovaps (%0), %%ymm0;"::"r"(src + d0 * lh * sw + d1 * lw) : "%ymm0");
			for (long l = 1; l < lh * lw; ++l)
				asm("vmaxps (%0), %%ymm0, %%ymm0"::"r"(src + (d0 * lh + l / lw) * sw + d1 * lw + l % lw) : "%ymm0");
			asm("vmovaps %%ymm0, (%0);"::"r"(des + d0 * dw + d1) : "%ymm0");
		}
}

/*
 * Per-lane argmax over the OUTPUT class scores: result[k] is the index j that
 * maximises output[j][k], for every lane k < SZPACK. Ties keep the lowest
 * index, because the label only moves on a strict `best < v`.
 *
 * The compare-then-max sequence keeps the result of the previous
 * vcmpltps/vmaxps code, including its NaN behaviour. That code kept the
 * running maximum in %ymm0 across separate asm statements, which the compiler
 * does not guarantee to preserve. It also hard-coded 8 lanes, whereas this
 * loop follows SZPACK.
 */
static void GetResult(pack_t output[OUTPUT], int result[SZPACK])
{
	/* Compiler barrier: output may have been written by inline asm that does
	 * not declare its memory accesses. */
	__asm__ __volatile__("" ::: "memory");
	for (int k = 0; k < SZPACK; ++k)
	{
		float best = output[0][k];
		int label = 0;
		for (int j = 1; j < OUTPUT; ++j)
		{
			const float v = output[j][k];
			if (best < v)
				label = j;
			best = best > v ? best : v;
		}
		result[k] = label;
	}
}
